r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter

To support these two classes, in `./_utils` we define many utility methods and
functions to be run in multiprocessing. E.g., the data loading worker loop is
in `./_utils/worker.py`.
"""

import threading
import itertools
import warnings
import multiprocessing as python_multiprocessing
import torch
import torch.multiprocessing as multiprocessing
from torch._utils import ExceptionWrapper
from torch._six import queue, string_classes
from torch.utils.data.dataset import IterableDataset
from torch.utils.data import Sampler, SequentialSampler, RandomSampler, BatchSampler
from torch.utils.data import _utils

from .my_data_worker import worker_loop

__all__ = ['MyDataLoader']

get_worker_info = _utils.worker.get_worker_info

# This function used to be defined in this file. However, it was moved to
# _utils/collate.py. Although it is rather hard to access this from user land
# (one has to explicitly `import torch.utils.data.dataloader`), there
# probably is user code out there using it. This aliasing maintains BC in this
# aspect.
default_collate = _utils.collate.default_collate


class _DatasetKind(object):
	Map = 0
	Iterable = 1

	@staticmethod
	def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
		if kind == _DatasetKind.Map:
			return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
		else:
			return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)


class _InfiniteConstantSampler(Sampler):
	r"""Analogous to ``itertools.repeat(None, None)``.
	Used as sampler for :class:`~torch.utils.data.IterableDataset`.

	Arguments:
		data_source (Dataset): dataset to sample from
	"""

	def __init__(self):
		super(_InfiniteConstantSampler, self).__init__(None)

	def __iter__(self):
		while True:
			yield None


class MyDataLoader(object):
	r"""
	Data loader. Combines a dataset and a sampler, and provides an iterable over
	the given dataset.

	The :class:`~torch.utils.data.DataLoader` supports both map-style and
	iterable-style datasets with single- or multi-process loading, customizing
	loading order and optional automatic batching (collation) and memory pinning.

	See :py:mod:`torch.utils.data` documentation page for more details.

	Arguments:
		dataset (Dataset): dataset from which to load the data.
		batch_size (int, optional): how many samples per batch to load
			(default: ``1``).
		shuffle (bool, optional): set to ``True`` to have the data reshuffled
			at every epoch (default: ``False``).
		sampler (Sampler, optional): defines the strategy to draw samples from
			the dataset. If specified, :attr:`shuffle` must be ``False``.
		batch_sampler (Sampler, optional): like :attr:`sampler`, but returns a batch of
			indices at a time. Mutually exclusive with :attr:`batch_size`,
			:attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
		num_workers (int, optional): how many subprocesses to use for data
			loading. ``0`` means that the data will be loaded in the main process.
			(default: ``0``)
		collate_fn (callable, optional): merges a list of samples to form a
			mini-batch of Tensor(s).  Used when using batched loading from a
			map-style dataset.
		pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
			into CUDA pinned memory before returning them.  If your data elements
			are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
			see the example below.
		drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
			if the dataset size is not divisible by the batch size. If ``False`` and
			the size of dataset is not divisible by the batch size, then the last batch
			will be smaller. (default: ``False``)
		timeout (numeric, optional): if positive, the timeout value for collecting a batch
			from workers. Should always be non-negative. (default: ``0``)
		worker_init_fn (callable, optional): If not ``None``, this will be called on each
			worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
			input, after seeding and before data loading. (default: ``None``)


	.. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
				 cannot be an unpicklable object, e.g., a lambda function. See
				 :ref:`multiprocessing-best-practices` on more details related
				 to multiprocessing in PyTorch.

	.. note:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
			  When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
			  ``len(dataset)`` (if implemented) is returned instead, regardless
			  of multi-process loading configurations, because PyTorch trusts
			  user :attr:`dataset` code to correctly handle multi-process
			  loading and avoid duplicate data. See `Dataset Types`_ for more
			  details on these two types of datasets and how
			  :class:`~torch.utils.data.IterableDataset` interacts with `Multi-process data loading`_.
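
	Example (a minimal sketch of memory pinning with a custom batch type;
	``SimpleCustomBatch`` is hypothetical and ``dataset`` is assumed to yield
	``(input, target)`` tensor pairs; this works because the pinning logic in
	``_utils.pin_memory`` falls back to calling a ``pin_memory()`` method on
	custom batch types)::

		class SimpleCustomBatch(object):
			def __init__(self, samples):
				inputs, targets = zip(*samples)
				self.inp = torch.stack(inputs, 0)
				self.tgt = torch.stack(targets, 0)

			def pin_memory(self):
				# pin each tensor held by this batch and return the batch itself
				self.inp = self.inp.pin_memory()
				self.tgt = self.tgt.pin_memory()
				return self

		loader = MyDataLoader(dataset, batch_size=2,
		                      collate_fn=SimpleCustomBatch, pin_memory=True)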
	"""

	__initialized = False

	def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
	             batch_sampler=None, num_workers=0, collate_fn=None,
	             pin_memory=False, drop_last=False, timeout=0,
	             worker_init_fn=None, multiprocessing_context=None):
		torch._C._log_api_usage_once("python.data_loader")

		if num_workers < 0:
			raise ValueError('num_workers option should be non-negative; '
			                 'use num_workers=0 to disable multiprocessing.')

		if timeout < 0:
			raise ValueError('timeout option should be non-negative')

		self.dataset = dataset
		self.num_workers = num_workers
		self.pin_memory = pin_memory
		self.timeout = timeout
		self.worker_init_fn = worker_init_fn
		self.multiprocessing_context = multiprocessing_context

		# Arg-check dataset related before checking samplers because we want to
		# tell users that iterable-style datasets are incompatible with custom
		# samplers first, so that they don't learn that this combo doesn't work
		# after spending time fixing the custom sampler errors.
		if isinstance(dataset, IterableDataset):
			self._dataset_kind = _DatasetKind.Iterable
			# NOTE [ Custom Samplers and `IterableDataset` ]
			#
			# `IterableDataset` does not support custom `batch_sampler` or
			# `sampler` since indices/keys are irrelevant for such datasets
			# (unless we support generator-style datasets one day...).
			#
			# For `sampler`, we always create a dummy sampler. This is an
			# infinite sampler even when the dataset may have an implemented
			# finite `__len__` because in multi-process data loading, naive
			# settings will return duplicated data (which may be desired), and
			# thus using a sampler whose length matches that of the dataset will
			# cause data loss (you may get duplicates of the first couple of
			# batches, but never see anything afterwards). Therefore,
			# `IterableDataset` always uses an infinite sampler, an instance of
			# `_InfiniteConstantSampler` defined above.
			#
			# A custom `batch_sampler` essentially only controls the batch size.
			# However, it is unclear how useful it would be since an iterable-style
			# dataset can handle that within itself. Moreover, it is pointless
			# in multi-process data loading, as the assignment order of batches
			# to workers is an implementation detail, so users cannot control
			# how to batchify each worker's iterable. Thus, we disable this
			# option. If this turns out to be useful in future, we can re-enable
			# this, and support custom samplers that specify the assignments to
			# specific workers.
			if shuffle is not False:
				raise ValueError(
					"DataLoader with IterableDataset: expected unspecified "
					"shuffle option, but got shuffle={}".format(shuffle))
			elif sampler is not None:
				# See NOTE [ Custom Samplers and IterableDataset ]
				raise ValueError(
					"DataLoader with IterableDataset: expected unspecified "
					"sampler option, but got sampler={}".format(sampler))
			elif batch_sampler is not None:
				# See NOTE [ Custom Samplers and IterableDataset ]
				raise ValueError(
					"DataLoader with IterableDataset: expected unspecified "
					"batch_sampler option, but got batch_sampler={}".format(batch_sampler))
		else:
			self._dataset_kind = _DatasetKind.Map

		if sampler is not None and shuffle:
			raise ValueError('sampler option is mutually exclusive with '
			                 'shuffle')

		if batch_sampler is not None:
			# auto_collation with custom batch_sampler
			if batch_size != 1 or shuffle or sampler is not None or drop_last:
				raise ValueError('batch_sampler option is mutually exclusive '
				                 'with batch_size, shuffle, sampler, and '
				                 'drop_last')
			batch_size = None
			drop_last = False
		elif batch_size is None:
			# no auto_collation
			if shuffle or drop_last:
				raise ValueError('batch_size=None option disables auto-batching '
				                 'and is mutually exclusive with '
				                 'shuffle and drop_last')

		if sampler is None:  # give default samplers
			if self._dataset_kind == _DatasetKind.Iterable:
				# See NOTE [ Custom Samplers and IterableDataset ]
				sampler = _InfiniteConstantSampler()
			else:  # map-style
				if shuffle:
					sampler = RandomSampler(dataset)
				else:
					sampler = SequentialSampler(dataset)

		if batch_size is not None and batch_sampler is None:
			# auto_collation without custom batch_sampler
			batch_sampler = BatchSampler(sampler, batch_size, drop_last)

		self.batch_size = batch_size
		self.drop_last = drop_last
		self.sampler = sampler
		self.batch_sampler = batch_sampler

		if collate_fn is None:
			if self._auto_collation:
				collate_fn = _utils.collate.default_collate
			else:
				collate_fn = _utils.collate.default_convert

		self.collate_fn = collate_fn
		self.__initialized = True
		self._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ]

	@property
	def multiprocessing_context(self):
		return self.__multiprocessing_context

	@multiprocessing_context.setter
	def multiprocessing_context(self, multiprocessing_context):
		if multiprocessing_context is not None:
			if self.num_workers > 0:
				if not multiprocessing._supports_context:
					raise ValueError('multiprocessing_context relies on Python >= 3.4, with '
					                 'support for different start methods')

				if isinstance(multiprocessing_context, string_classes):
					valid_start_methods = multiprocessing.get_all_start_methods()
					if multiprocessing_context not in valid_start_methods:
						raise ValueError(
							('multiprocessing_context option '
							 'should specify a valid start method in {}, but got '
							 'multiprocessing_context={}').format(valid_start_methods, multiprocessing_context))
					multiprocessing_context = multiprocessing.get_context(multiprocessing_context)

				if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
					raise ValueError(('multiprocessing_context option should be a valid context '
					                  'object or a string specifying the start method, but got '
					                  'multiprocessing_context={}').format(multiprocessing_context))
			else:
				raise ValueError(('multiprocessing_context can only be used with '
				                  'multi-process loading (num_workers > 0), but got '
				                  'num_workers={}').format(self.num_workers))

		self.__multiprocessing_context = multiprocessing_context
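
	# For example (a hedged sketch; `ds` stands for any map-style dataset and the
	# set of available start methods depends on the platform), both of the
	# following are accepted when `num_workers > 0`:
	#
	#     MyDataLoader(ds, num_workers=2, multiprocessing_context='spawn')
	#     MyDataLoader(ds, num_workers=2,
	#                  multiprocessing_context=python_multiprocessing.get_context('spawn'))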

	def __setattr__(self, attr, val):
		if self.__initialized and attr in ('batch_size', 'batch_sampler', 'sampler', 'drop_last', 'dataset'):
			raise ValueError('{} attribute should not be set after {} is '
			                 'initialized'.format(attr, self.__class__.__name__))

		super(MyDataLoader, self).__setattr__(attr, val)

	def __iter__(self):
		if self.num_workers == 0:
			return _SingleProcessDataLoaderIter(self)
		else:
			return _MultiProcessingDataLoaderIter(self)

	@property
	def _auto_collation(self):
		return self.batch_sampler is not None

	@property
	def _index_sampler(self):
		# The actual sampler used for generating indices for `_DatasetFetcher`
		# (see _utils/fetch.py) to read data at each time. This would be
		# `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
		# We can't change `.sampler` and `.batch_sampler` attributes for BC
		# reasons.
		if self._auto_collation:
			return self.batch_sampler
		else:
			return self.sampler

	def __len__(self):
		if self._dataset_kind == _DatasetKind.Iterable:
			# NOTE [ IterableDataset and __len__ ]
			#
			# For `IterableDataset`, `__len__` could be inaccurate when one naively
			# does multi-processing data loading, since the samples will be duplicated.
			# However, no real use case should be actually using that behavior, so
			# it should count as a user error. We should generally trust user
			# code to do the proper thing (e.g., configure each replica differently
			# in `__iter__`), and give us the correct `__len__` if they choose to
			# implement it (this will still throw if the dataset does not implement
			# a `__len__`).
			#
			# To provide a further warning, we track if `__len__` was called on the
			# `DataLoader`, save the returned value in
			# `self._IterableDataset_len_called`, and warn if the iterator ends up
			# yielding more than this number of samples. (See the sketch after this
			# method for how user code can shard an `IterableDataset` per worker.)
			length = self._IterableDataset_len_called = len(self.dataset)
			return length
		else:
			return len(self._index_sampler)
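
	# NOTE [ Example: sharding an IterableDataset across workers ]
	#
	# A minimal sketch (not part of this module; `RangeIterableDataset` is
	# hypothetical) of how user `__iter__` code can use `get_worker_info()` to
	# configure each replica so that multi-process loading does not duplicate
	# data and the `__len__` reported above stays meaningful:
	#
	#     class RangeIterableDataset(IterableDataset):
	#         def __init__(self, start, end):
	#             self.start, self.end = start, end
	#
	#         def __len__(self):
	#             return self.end - self.start
	#
	#         def __iter__(self):
	#             info = get_worker_info()
	#             if info is None:
	#                 # single-process loading: yield the full range
	#                 return iter(range(self.start, self.end))
	#             # worker process: yield only this worker's contiguous shard
	#             per_worker = -(-(self.end - self.start) // info.num_workers)  # ceil div
	#             shard_start = self.start + info.id * per_worker
	#             return iter(range(shard_start, min(shard_start + per_worker, self.end)))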


class _BaseDataLoaderIter(object):
	def __init__(self, loader):
		self._dataset = loader.dataset
		self._dataset_kind = loader._dataset_kind
		self._IterableDataset_len_called = loader._IterableDataset_len_called
		self._auto_collation = loader._auto_collation
		self._drop_last = loader.drop_last
		self._index_sampler = loader._index_sampler
		self._num_workers = loader.num_workers
		self._pin_memory = loader.pin_memory and torch.cuda.is_available()
		self._timeout = loader.timeout
		self._collate_fn = loader.collate_fn
		self._sampler_iter = iter(self._index_sampler)
		self._base_seed = torch.empty((), dtype=torch.int64).random_().item()
		self._num_yielded = 0

	def __iter__(self):
		return self

	def _next_index(self):
		return next(self._sampler_iter)  # may raise StopIteration

	def _next_data(self):
		raise NotImplementedError

	def __next__(self):
		data = self._next_data()
		self._num_yielded += 1
		if self._dataset_kind == _DatasetKind.Iterable and \
				self._IterableDataset_len_called is not None and \
				self._num_yielded > self._IterableDataset_len_called:
			warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
			            "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
			                                                  self._num_yielded)
			if self._num_workers > 0:
				warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
				             "IterableDataset replica at each worker. Please see "
				             "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
			warnings.warn(warn_msg)
		return data

	next = __next__  # Python 2 compatibility

	def __len__(self):
		return len(self._index_sampler)

	def __getstate__(self):
		# TODO: add limited pickling support for sharing an iterator
		# across multiple threads for HOGWILD.
		# Probably the best way to do this is by moving the sample pushing
		# to a separate thread and then just sharing the data queue,
		# but signalling the end is tricky without a non-blocking API.
		raise NotImplementedError("{} cannot be pickled".format(self.__class__.__name__))


class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
	def __init__(self, loader):
		super(_SingleProcessDataLoaderIter, self).__init__(loader)
		assert self._timeout == 0
		assert self._num_workers == 0

		self._dataset_fetcher = _DatasetKind.create_fetcher(
			self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

	def _next_data(self):
		index = self._next_index()  # may raise StopIteration
		data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
		if self._pin_memory:
			data = _utils.pin_memory.pin_memory(data)
		return data


class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
	r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""

	# NOTE [ Data Loader Multiprocessing Shutdown Logic ]
	#
	# Preliminary:
	#
	# Our data model looks like this (queues are indicated with curly brackets):
	#
	#                main process                              ||
	#                     |                                    ||
	#               {index_queue}                              ||
	#                     |                                    ||
	#              worker processes                            ||     DATA
	#                     |                                    ||
	#            {worker_result_queue}                         ||     FLOW
	#                     |                                    ||
	#      pin_memory_thread of main process                   ||   DIRECTION
	#                     |                                    ||
	#               {data_queue}                               ||
	#                     |                                    ||
	#                data output                               \/
	#
	# P.S. The `worker_result_queue` and `pin_memory_thread` parts may be omitted
	#      if `pin_memory=False`.
	#
	#
	# Terminating multiprocessing logic requires very careful design. In
	# particular, we need to make sure that
	#
	#   1. The iterator gracefully exits the workers when its last reference is
	#      gone or it is depleted.
	#
	#      In this case, the workers should be gracefully exited because the
	#      main process may still need to continue to run, and we want cleaning
	#      up code in the workers to be executed (e.g., releasing GPU memory).
	#      Naturally, we implement the shutdown logic in `__del__` of the
	#      DataLoader iterator.
	#
	#      We delay the discussion on the logic in this case until later.
	#
	#   2. The iterator exits the workers when the loader process and/or worker
	#      processes exit normally or with an error.
	#
	#      We set all workers and `pin_memory_thread` to have `daemon=True`.
	#
	#      You may ask, why can't we make the workers non-daemonic, and
	#      gracefully exit using the same logic as we have in `__del__` when the
	#      iterator gets deleted (see 1 above)?
	#
	#      First of all, `__del__` is **not** guaranteed to be called when
	#      interpreter exits. Even if it is called, by the time it executes,
	#      many Python core library resources may already be freed, and even
	#      simple things like acquiring an internal lock of a queue may hang.
	#      Therefore, in this case, we actually need to prevent `__del__` from
	#      being executed, and rely on the automatic termination of daemonic
	#      children. Thus, we register an `atexit` hook that sets a global flag
	#      `_utils.python_exit_status`. Since `atexit` hooks are executed in the
	#      reverse order of registration, we are guaranteed that this flag is
	#      set before library resources we use are freed. (Hooks freeing those
	#      resources are registered at importing the Python core libraries at
	#      the top of this file.) So in `__del__`, we check if
	#      `_utils.python_exit_status` is set or `None` (freed), and perform
	#      no-op if so.
	#
	#      Another problem with `__del__` is also related to the library cleanup
	#      calls. When a process ends, it shuts all its daemonic children
	#      down with a SIGTERM (instead of joining them without a timeout).
	#      Similarly for threads, but by a different mechanism. This fact,
	#      together with a few implementation details of multiprocessing, forces
	#      us to make workers daemonic. All of our problems arise when a
	#      DataLoader is used in a subprocess, and are caused by multiprocessing
	#      code which looks more or less like this:
	#
	#          try:
	#              your_function_using_a_dataloader()
	#          finally:
	#              multiprocessing.util._exit_function()
	#
	#      The joining/termination mentioned above happens inside
	#      `_exit_function()`. Now, if `your_function_using_a_dataloader()`
	#      throws, the stack trace stored in the exception will prevent the
	#      frame which uses `DataLoaderIter` from being freed. If the frame has any
	#      reference to the `DataLoaderIter` (e.g., in a method of the iter),
	#      its `__del__`, which starts the shutdown procedure, will not be
	#      called. That, in turn, means that workers aren't notified. Attempting
	#      to join in `_exit_function` will then result in a hang.
	#
	#      For context, `_exit_function` is also registered as an `atexit` call.
	#      So it is unclear to me (@ssnl) why this is needed in a finally block.
	#      The code dates back to 2008 and there is no comment on the original
	#      PEP 371 or patch https://bugs.python.org/issue3050 (containing both
	#      the finally block and the `atexit` registration) that explains this.
	#
	#      Another choice is to just shutdown workers with logic in 1 above
	#      whenever we see an error in `next`. This isn't ideal because
	#        a. It prevents users from using try-catch to resume data loading.
	#        b. It doesn't prevent hanging if users have references to the
	#           iterator.
	#
	#   3. All processes exit if any of them die unexpectedly by fatal signals.
	#
	#      As shown above, the workers are set as daemonic children of the main
	#      process. However, automatic cleaning-up of such child processes only
	#      happens if the parent process exits gracefully (e.g., not via fatal
	#      signals like SIGKILL). So we must ensure that each process will exit
	#      even if the process that is supposed to send/receive data to/from it
	#      was killed, i.e.,
	#
	#        a. A process won't hang when getting from a queue.
	#
	#           Even with carefully designed data dependencies (i.e., a `put()`
	#           always corresponding to a `get()`), hanging on `get()` can still
	#           happen when data in queue is corrupted (e.g., due to
	#           `cancel_join_thread` or unexpected exit).
	#
	#           For child exit, we set a timeout whenever we try to get data
	#           from `data_queue`, and check the workers' status on each timeout
	#           and error.
	#           See `_MultiProcessingDataLoaderIter._get_data()` and
	#           `_MultiProcessingDataLoaderIter._try_get_data()` for details.
	#
	#           Additionally, for child exit on non-Windows platforms, we also
	#           register a SIGCHLD handler (which is not supported on Windows) on
	#           the main process, which checks if any of the workers fail in the
	#           (Python) handler. This is more efficient and faster in detecting
	#           worker failures, compared to only using the above mechanism.
	#           See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
	#
	#           For `.get()` calls where the sender is not a worker, we
	#           guard them with timeouts, and check the status of the sender
	#           when a timeout happens:
	#             + in the workers, the `_utils.worker.ManagerWatchdog` class
	#               checks the status of the main process.
	#             + if `pin_memory=True`, when getting from `pin_memory_thread`,
	#               check `pin_memory_thread` status periodically until `.get()`
	#               returns or we see that `pin_memory_thread` has died.
	#
	#        b. A process won't hang when putting into a queue;
	#
	#           We use `mp.Queue` which has a separate background thread to put
	#           objects from an unbounded buffer array. The background thread is
	#           daemonic and usually automatically joined when the process
	#           exits.
	#
	#           However, in case the receiver has ended abruptly while
	#           reading from the pipe, the join will hang forever. Therefore,
	#           for both `worker_result_queue` (worker -> main process/pin_memory_thread)
	#           and each `index_queue` (main process -> worker), we use
	#           `q.cancel_join_thread()` in sender process before any `q.put` to
	#           prevent this automatic join.
	#
	#           Moreover, having called `cancel_join_thread` on all queues makes
	#           implementing graceful shutdown logic in `__del__` much easier.
	#           It won't need to get from any queue, which would also need to be
	#           guarded by periodic status checks.
	#
	#           Nonetheless, `cancel_join_thread` must only be called when the
	#           queue is **not** going to be read from or written into by another
	#           process, because it may hold onto a lock or leave corrupted data
	#           in the queue, leading other readers/writers to hang.
	#
	#           `pin_memory_thread`'s `data_queue` is a `queue.Queue` that does
	#           a blocking `put` if the queue is full. So the problem above does
	#           not apply, but we do need to wrap the `put` in a loop that breaks
	#           not only upon success, but also when the main process stops
	#           reading, i.e., is shutting down.
	#
	#
	# Now let's get back to 1:
	#   how we gracefully exit the workers when the last reference to the
	#   iterator is gone.
	#
	# To achieve this, we implement the following logic along with the design
	# choices mentioned above:
	#
	# `workers_done_event`:
	#   A `multiprocessing.Event` shared among the main process and all worker
	#   processes. This is used to signal the workers that the iterator is
	#   shutting down. After it is set, they will not send processed data to
	#   queues anymore, and only wait for the final `None` before exiting.
	#   `done_event` isn't strictly needed. I.e., we can just check for `None`
	#   from the input queue, but it allows us to skip wasting resources
	#   processing data if we are already shutting down.
	#
	# `pin_memory_thread_done_event`:
	#   A `threading.Event` for a similar purpose to that of
	#   `workers_done_event`, but is for the `pin_memory_thread`. The reason
	#   that separate events are needed is that `pin_memory_thread` reads from
	#   the output queue of the workers. But the workers, upon seeing that
	#   `workers_done_event` is set, only want to see the final `None`, and are
	#   not required to flush all data in the output queue (e.g., a worker may
	#   call `cancel_join_thread` on that queue if its `IterableDataset` iterator
	#   happens to be exhausted coincidentally, which is out of the control of
	#   the main process). Thus, since we will exit `pin_memory_thread` before
	#   the workers (see below), two separate events are used.
	#
	# NOTE: In short, the protocol is that the main process will set these
	#       `done_event`s and then send the corresponding processes/threads a
	#       `None`, and that they may exit at any time after receiving the `None`.
	#
	# NOTE: Using `None` as the final signal is valid, since normal data will
	#       always be a 2-tuple with the 1st element being the index of the data
	#       transferred (different from dataset index/key), and the 2nd being
	#       either the dataset key or the data sample (depending on which part
	#       of the data model the queue is at).
	#
	# [ worker processes ]
	#   While loader process is alive:
	#     Get from `index_queue`.
	#       If got a `None`, exit.
	#       If got anything else,
	#          Check `workers_done_event`.
	#            If set, continue to next iteration,
	#                    i.e., keep getting until we see the `None`, then exit.
	#            Otherwise, process data:
	#                If fetching from an `IterableDataset` and the iterator
	#                    is exhausted, send an `_IterableDatasetStopIteration`
	#                    object to signal iteration end. The main process, upon
	#                    receiving such an object, will send `None` to this
	#                    worker and not use the corresponding `index_queue`
	#                    anymore.
	#       If timed out,
	#          No matter whether `workers_done_event` is set (we still need to
	#          see the `None`) or not, must continue to the next iteration.
	#   (outside loop)
	#   If `workers_done_event` is set,  (this can be False with `IterableDataset`)
	#     `data_queue.cancel_join_thread()`.  (Everything is ending here:
	#                                          main process won't read from it;
	#                                          other workers will also call
	#                                          `cancel_join_thread`.)
	#
	# [ pin_memory_thread ]
	#   # No need to check main thread. If this thread is alive, the main loader
	#   # thread must be alive, because this thread is set as daemonic.
	#   While `pin_memory_thread_done_event` is not set:
	#     Get from `worker_result_queue`.
	#       If timed out, continue to get in the next iteration.
	#       Otherwise, process data.
	#       While `pin_memory_thread_done_event` is not set:
	#         Put processed data into `data_queue` (a `queue.Queue` with blocking put).
	#         If timed out, continue to put in the next iteration.
	#         Otherwise, break, i.e., continue to the outer loop.
	#
	#   NOTE: we don't check the status of the main thread because
	#           1. if the process is killed by fatal signal, `pin_memory_thread`
	#              ends.
	#           2. in other cases, either the cleaning-up in __del__ or the
	#              automatic exit of daemonic thread will take care of it.
	#              This won't busy-wait either because `.get(timeout)` does not
	#              busy-wait.
	#
	# [ main process ]
	#   In the DataLoader Iter's `__del__`
	#     b. Exit `pin_memory_thread`
	#          i.   Set `pin_memory_thread_done_event`.
	#          ii   Put `None` in `worker_result_queue`.
	#          iii. Join the `pin_memory_thread`.
	#          iv.  `worker_result_queue.cancel_join_thread()`.
	#
	#     c. Exit the workers.
	#          i.   Set `workers_done_event`.
	#          ii.  Put `None` in each worker's `index_queue`.
	#          iii. Join the workers.
	#          iv.  Call `.cancel_join_thread()` on each worker's `index_queue`.
	#
	#        NOTE: (c) is better placed after (b) because it may leave corrupted
	#              data in `worker_result_queue`, which `pin_memory_thread`
	#              reads from, in which case the `pin_memory_thread` can only
	#              exit after timing out, which is slow. Nonetheless, the same thing
	#              happens if a worker is killed by a signal at an unfortunate time,
	#              but in other cases, we are better off having a non-corrupted
	#              `worker_result_queue` for `pin_memory_thread`.
	#
	#   NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
	#         can be omitted.
	#
	# NB: `done_event`s aren't strictly needed. E.g., we can just check for
	#     `None` from `index_queue`, but it allows us to skip wasting resources
	#     processing indices already in `index_queue` if we are already shutting
	#     down.

	def __init__(self, loader):
		super(_MultiProcessingDataLoaderIter, self).__init__(loader)

		assert self._num_workers > 0

		if loader.multiprocessing_context is None:
			multiprocessing_context = multiprocessing
		else:
			multiprocessing_context = loader.multiprocessing_context

		self._worker_init_fn = loader.worker_init_fn
		self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
		self._worker_result_queue = multiprocessing_context.Queue()
		self._worker_pids_set = False
		self._shutdown = False
		self._send_idx = 0  # idx of the next task to be sent to workers
		self._rcvd_idx = 0  # idx of the next task to be returned in __next__
		# information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
		# map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
		#                  \ (worker_id, data)   if data is already fetched (out-of-order)
		self._task_info = {}
		self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
		self._workers_done_event = multiprocessing_context.Event()

		self._index_queues = []
		self._workers = []
		# A list of booleans representing whether each worker still has work to
		# do, i.e., not having exhausted its iterable dataset object. It always
		# contains all `True`s if not using an iterable-style dataset
		# (i.e., if kind != Iterable).
		self._workers_status = []
		for i in range(self._num_workers):
			index_queue = multiprocessing_context.Queue()
			# index_queue.cancel_join_thread()
			w = multiprocessing_context.Process(
				target=worker_loop,
				args=(self._dataset_kind, self._dataset, index_queue,
				      self._worker_result_queue, self._workers_done_event,
				      self._auto_collation, self._collate_fn, self._drop_last,
				      self._base_seed + i, self._worker_init_fn, i, self._num_workers))
			w.daemon = True
			# NB: Process.start() actually takes some time as it needs to
			#     start a process and pass the arguments over via a pipe.
			#     Therefore, we only add a worker to the self._workers list after
			#     it has started, so that if the program dies before the worker
			#     starts, __del__ does not try to join it and fail with:
			#     AssertionError: can only join a started process.
			w.start()
			self._index_queues.append(index_queue)
			self._workers.append(w)
			self._workers_status.append(True)

		if self._pin_memory:
			self._pin_memory_thread_done_event = threading.Event()
			self._data_queue = queue.Queue()
			pin_memory_thread = threading.Thread(
				target=_utils.pin_memory._pin_memory_loop,
				args=(self._worker_result_queue, self._data_queue,
				      torch.cuda.current_device(),
				      self._pin_memory_thread_done_event))
			pin_memory_thread.daemon = True
			pin_memory_thread.start()
			# Similar to workers (see comment above), we only register
			# pin_memory_thread once it is started.
			self._pin_memory_thread = pin_memory_thread
		else:
			self._data_queue = self._worker_result_queue

		_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))
		_utils.signal_handling._set_SIGCHLD_handler()
		self._worker_pids_set = True

		# prime the prefetch loop
		for _ in range(2 * self._num_workers):
			self._try_put_index()

	def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
		# Tries to fetch data from `self._data_queue` once for a given timeout.
		# This can also be used as inner loop of fetching without timeout, with
		# the sender status as the loop condition.
		#
		# This raises a `RuntimeError` if any worker died unexpectedly. This error
		# can come from either the SIGCHLD handler in `_utils/signal_handling.py`
		# (only for non-Windows platforms), or the manual check below on errors
		# and timeouts.
		#
		# Returns a 2-tuple:
		#   (bool: whether we successfully got data, any: data if successful else None)
		try:
			data = self._data_queue.get(timeout=timeout)
			return (True, data)
		except Exception as e:
			# At timeout and error, we manually check whether any worker has
			# failed. Note that this is the only mechanism for Windows to detect
			# worker failures.
			failed_workers = []
			for worker_id, w in enumerate(self._workers):
				if self._workers_status[worker_id] and not w.is_alive():
					failed_workers.append(w)
					self._shutdown_worker(worker_id)
			if len(failed_workers) > 0:
				pids_str = ', '.join(str(w.pid) for w in failed_workers)
				raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
			if isinstance(e, queue.Empty):
				return (False, None)
			raise

	def _get_data(self):
		# Fetches data from `self._data_queue`.
		#
		# We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
		# which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
		# in a loop. This is the only mechanism to detect worker failures for
		# Windows. For other platforms, a SIGCHLD handler is also used for
		# worker failure detection.
		#
		# If `pin_memory=True`, we also need to check if `pin_memory_thread` has
		# died at timeouts.
		if self._timeout > 0:
			success, data = self._try_get_data(self._timeout)
			if success:
				return data
			else:
				raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))
		elif self._pin_memory:
			while self._pin_memory_thread.is_alive():
				success, data = self._try_get_data()
				if success:
					return data
			else:
				# while condition is false, i.e., pin_memory_thread died.
				raise RuntimeError('Pin memory thread exited unexpectedly')
			# In this case, `self._data_queue` is a `queue.Queue`. But we don't
			# need to call `.task_done()` because we don't use `.join()`.
		else:
			while True:
				success, data = self._try_get_data()
				if success:
					return data

	def _next_data(self):
		while True:
			# If the worker responsible for `self._rcvd_idx` has already ended
			# and was unable to fulfill this task (due to exhausting an `IterableDataset`),
			# we try to advance `self._rcvd_idx` to find the next valid index.
			#
			# This part needs to run in the loop because both the `self._get_data()`
			# call and `_IterableDatasetStopIteration` check below can mark
			# extra worker(s) as dead.
			while self._rcvd_idx < self._send_idx:
				info = self._task_info[self._rcvd_idx]
				worker_id = info[0]
				if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
					break
				del self._task_info[self._rcvd_idx]
				self._rcvd_idx += 1
			else:
				# no valid `self._rcvd_idx` is found (i.e., didn't break)
				self._shutdown_workers()
				raise StopIteration

			# Now `self._rcvd_idx` is the batch index we want to fetch

			# Check if the next sample has already been generated
			if len(self._task_info[self._rcvd_idx]) == 2:
				data = self._task_info.pop(self._rcvd_idx)[1]
				return self._process_data(data)

			assert not self._shutdown and self._tasks_outstanding > 0
			idx, data = self._get_data()
			self._tasks_outstanding -= 1

			if self._dataset_kind == _DatasetKind.Iterable:
				# Check for _IterableDatasetStopIteration
				if isinstance(data, _utils.worker._IterableDatasetStopIteration):
					self._shutdown_worker(data.worker_id)
					self._try_put_index()
					continue

			if idx != self._rcvd_idx:
				# store out-of-order samples
				self._task_info[idx] += (data,)
			else:
				del self._task_info[idx]
				return self._process_data(data)

	def _try_put_index(self):
		assert self._tasks_outstanding < 2 * self._num_workers
		try:
			index = self._next_index()
		except StopIteration:
			return
		for _ in range(self._num_workers):  # find the next active worker, if any
			worker_queue_idx = next(self._worker_queue_idx_cycle)
			if self._workers_status[worker_queue_idx]:
				break
		else:
			# not found (i.e., didn't break)
			return

		self._index_queues[worker_queue_idx].put((self._send_idx, index))
		self._task_info[self._send_idx] = (worker_queue_idx,)
		self._tasks_outstanding += 1
		self._send_idx += 1

	def _process_data(self, data):
		self._rcvd_idx += 1
		self._try_put_index()
		if isinstance(data, ExceptionWrapper):
			data.reraise()
		return data

	def _shutdown_worker(self, worker_id):
		# Mark a worker as having finished its work and dead, e.g., due to
		# exhausting an `IterableDataset`. This should be used only when this
		# `_MultiProcessingDataLoaderIter` is going to continue running.

		assert self._workers_status[worker_id]

		# Signal termination to that specific worker.
		q = self._index_queues[worker_id]
		# Indicate that no more data will be put on this queue by the current
		# process.
		q.put(None)

		# Note that we don't actually join the worker here, nor do we remove the
		# worker's pid from C side struct because (1) joining may be slow, and
		# (2) since we don't join, the worker may still raise error, and we
		# prefer capturing those, rather than ignoring them, even though they
		# are raised after the worker has finished its job.
		# Joining is deferred to `_shutdown_workers`, which is called when
		# all workers finish their jobs (e.g., `IterableDataset` replicas) or
		# when this iterator is garbage collected.
		self._workers_status[worker_id] = False

	def _shutdown_workers(self):
		# Called when shutting down this `_MultiProcessingDataLoaderIter`.
		# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
		# the logic of this function.
		python_exit_status = _utils.python_exit_status
		if python_exit_status is True or python_exit_status is None:
			# See (2) of the note. If Python is shutting down, do no-op.
			return
		# Normal exit when last reference is gone / iterator is depleted.
		# See (1) and the second half of the note.
		if not self._shutdown:
			self._shutdown = True
			try:
				# Exit `pin_memory_thread` first because exiting workers may leave
				# corrupted data in `worker_result_queue` which `pin_memory_thread`
				# reads from.
				if hasattr(self, '_pin_memory_thread'):
					# Use hasattr in case error happens before we set the attribute.
					self._pin_memory_thread_done_event.set()
					# Send something to pin_memory_thread in case it is waiting
					# so that it can wake up and check `pin_memory_thread_done_event`
					self._worker_result_queue.put((None, None))
					self._pin_memory_thread.join()
					self._worker_result_queue.close()

				# Exit workers now.
				self._workers_done_event.set()
				for worker_id in range(len(self._workers)):
					# Get number of workers from `len(self._workers)` instead of
					# `self._num_workers` in case we error before starting all
					# workers.
					if self._workers_status[worker_id]:
						self._shutdown_worker(worker_id)
				for w in self._workers:
					w.join()
				for q in self._index_queues:
					q.cancel_join_thread()
					q.close()
			finally:
				# Even though all this function does is putting into queues that
				# we have called `cancel_join_thread` on, weird things can
				# happen when a worker is killed by a signal, e.g., hanging in
				# `Event.set()`. So we need to guard this with SIGCHLD handler,
				# and remove pids from the C side data structure only at the
				# end.
				#
				# FIXME: Unfortunately, for Windows, we are missing a worker
				#        error detection mechanism here in this function, as it
				#        doesn't provide a SIGCHLD handler.
				if self._worker_pids_set:
					_utils.signal_handling._remove_worker_pids(id(self))
					self._worker_pids_set = False

	def __del__(self):
		self._shutdown_workers()
